Add optional Approximate Top-K configuration for MLA Indexer#4243
Merged
Conversation
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
9ef80f2 to
53d7917
Compare
Collaborator
Author
|
Code quality checker failed from existing file content |
NuojCheng
reviewed
Jun 25, 2026
|
|
||
| # Assert that the actual recall is close to or exceeds the target. | ||
| # We allow a small margin (e.g., 0.05) due to the approximate nature and sample size. | ||
| self.assertGreaterEqual(mean_recall, recall_target - 0.05) |
Collaborator
There was a problem hiding this comment.
Is 0.05 too large? What about making it 0.01?
NuojCheng
approved these changes
Jun 25, 2026
NuojCheng
left a comment
Collaborator
There was a problem hiding this comment.
I think it is a very cool optimization! Thank you Jiahao
gobbleturk
approved these changes
Jun 26, 2026
e4cc351 to
76dae1b
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR adds an optional configuration parameter
indexer_use_approx_top_kto the Multi-head Latent Attention (MLA) Indexer, allowing users to enable JAX's TPU-optimizedapprox_max_kprimitive instead of the default exacttop_kselection.Why is this change being made?
During performance investigations of DeepSeek-V3.2 in long-context mode (128K sequence length) with 128-way Context Parallelism (CP=128), the
top-kwas identified a major bottleneck in the MLA Indexer forward pass:[1, 1024, 131072]on each device.top_kselection introduces a massive step-time overhead for sorting 131K elements per layer across 58 layers.Benchmarks show a 4x speedup on f32[1,1024,65536] tensors when tested with the approximate path enabled using a recall target of 0.95.
Why this is a good solution
JAX's$\mathcal{O}(N \log^2 N)$ to $\approx \mathcal{O}(N + K \log K)$ with a significantly smaller constant factor. Paper: https://arxiv.org/pdf/2206.14286.
approx_max_kemploys block-based reduction optimized for TPU Matrix Units (MXU). It reduces complexity fromWorkload show a ~4x speedup on
f32[1, 1024, 65536]tensors when tested on TPU with the approximate path enabled using a recall target of 0.95.Specific Implementation Details
indexer_use_approx_top_kandindexer_approx_top_k_recallto theAttentionIndexerPydantic class to pass configuration validation.indexer_use_approx_top_k: false,indexer_approx_top_k_recall: 0.95).Indexer.__call__to conditionally route the selection tojax.lax.approx_max_kwhen enabled.Shortcomings & Future Improvements
indexer_use_approx_top_kinstead oftop_kmight affect downstream model performance, while it is expected to be minimal when a high recall rate is used.Tests
1. Regression Guard (Default Path)
We ran the attention unit test suite with the default configuration (
indexer_use_approx_top_k=false) to ensure no regressions:pytest tests/unit/attention_test.py2. Compilation & Tracing Safety
We added a new unit test,
test_indexer_with_approx_top_k, to verify that the new path compiles and traces successfully in JAX:pytest tests/unit/attention_test.py -k test_indexer_with_approx_top_k3. Mathematical Correctness & Recall Tracking
We added a correctness test,$K=64$ ), and calculates the actual recall:
test_approx_top_k_recall, which generates random scores of shape[4, 16, 1024], runs both exact and approximate top-K (pytest tests/unit/attention_test.py -k test_approx_top_k_recall -sChecklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.